-
Notifications
You must be signed in to change notification settings - Fork 489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support PJRT C API create_options
#6289
Conversation
Our resnet example actually works on 4 v100s with the plugin now! |
|
"allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default", | ||
"memory_fraction": xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75), | ||
"preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are these env var new?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These all exist in env_vars.cc/h
already
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you want to add any test for GPU plugin?
I don't actually have a good build process set up yet for this plugin. I'll work on that this week. Unless we find an issue, I'm inclined to just move the whole CI over to use the plugin once the build is automatable. WDYT? |
return { | ||
"platform_name": "gpu", | ||
# TODO(wcromar): make this configurable | ||
"allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for these 3 settings, is it possible not to use the hardcoded default settings: False, .75, True, such as
xla/torch_xla/csrc/runtime/pjrt_registry.cc
Lines 22 to 27 in 4bf8d44
auto allocator_config = xla::GpuAllocatorConfig{}; | |
if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && | |
sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && | |
sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { | |
return allocator_config; | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. Removed some options entirely when the environment variable is not set.
@will-cromar How will the UX be(for both building from source and installing from whl)? I want to make sure we don't require user to do additional steps to use the plug in after we make it default. |
For installing, we can add an extra requirement like we have for For building, you'll may have to build both the plugin and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: it may be possible to directly use the torch.distributed store here directly instead.
We briefly chatted about this offline, but perhaps we should push XlaCoordinator's distributed KV store up into the Python layer to replace TCPStore in our XLA backend implementation instead of dropping XlaCoordinator. Autocheckpointing will still require XlaCoordinator even if we move to torch.distributed
's kv store.
Just raising this for discussion. Overall this looks great, thanks Will!
struct PluginEntry { | ||
std::string library_path; | ||
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options; | ||
bool init_coordinator; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should drop init_coordinator
and instead always initialize the coordinator when the distributed env vars are set, even on TPU where it's not strictly necessary. As long as torchrun launches the training in a distributed context, the env vars should be set, which I believe covers all GPU use cases since we plan to use torchrun for GPU SPMD (cc @vanbasten23).
On TPU it's not currently required, but if that changes we can always detect the environment from the GCE metadata and set the env vars automatically for the user in a distributed context, since we don't require torchrun for multicontroller execution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should drop init_coordinator and instead always initialize the coordinator when the distributed env vars are set, even on TPU where it's not strictly necessary.
In this case, I think we still want to keep requires_xla_coordinator
option. We would just throw an error immediately if we don't have enough information to start the coordinator.
My other idea initially was that we could ask the plugin for the master IP, local rank, global rank, and world size, perhaps just asking torch.distributed
for those values in the default implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see - that makes sense. Just for context why I brought this up, JAX recently started requiring the coordinator to be initialized before the backend can be used (even on TPUs), but I'm not sure on the reason.
Keeping init_coordinator
sounds fine to me. Thanks Will!
When debugging the pjrt c api for gpu, right now I can change pytorch/xla/WORKSPACE to use a local open XLA, add |
The GPU plugin binary still uses the same bazel workspace, so any changes you make there will apply to both the plugin and |
Registering the plugins right away is going to be inevitably flaky, because they will register their create options immediately. This doesn't allow the user to change settings at all after I'll go learn how to define the plugin base class in C++ so the |
I started work on defining the interface in C++ in #6360. Originally, I wanted to wait to merge this PR until #6022 goes in, but that is apparently blocked by an XLA bug. What do you all think of merging this PR now to unblock other work? It does not alter the default behavior at all, so it will not be a breaking change. |
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr; | ||
if (plugin->init_coordinator) { | ||
int global_process_rank = sys_util::GetEnvInt("RANK", 0); | ||
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
didn't we already get it in the plugin->create_options
in python?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
create_options
doesn't include the rank and world size. In the future, we should actually just be creating the XlaCoordinator
in Python and using it at torch.distributed
's Store
. See #6289 (review)
return int(os.getenv('GPU_NUM_DEVICES', '1')) | ||
return xu.getenv_as('GPU_NUM_DEVICES', int, 1) | ||
|
||
def client_create_options(self) -> dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
who will call this client_create_options
method here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plugins.py
will call it during plugin registration in this PR. In the follow up, this will be called when the client is created.
@@ -10,4 +11,37 @@ def library_path(self) -> str: | |||
|
|||
def physical_chip_count(self) -> int: | |||
# TODO: default to actual device count | |||
return int(os.getenv('GPU_NUM_DEVICES', '1')) | |||
return xu.getenv_as('GPU_NUM_DEVICES', int, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps irrelevant to this PR but just want to confirm that the # TODO: default to actual device count
still holds, since GPU_NUM_DEVICES
is not always set and the default value may not be 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see it's only used in run_multiprocess
. So it sound using GPU_NUM_DEVICES
preserves the current behavior. Looks good to me then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this is the same as the current behavior. Ideally this should check the PCI device IDs like we do for TPUs.
return {k: v for k, v in options.items() if v is not None} | ||
|
||
def requires_xla_coordinator(self) -> bool: | ||
return True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder why it always return True. For single processing, probably we don't need the coordinator? So should it depend on whether it is single processing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in a previous draft, I caught this case in InitializePjrt
. But as @jonb377 says, a plugin may always require the coordinator. I'll go ahead and update this to return global_world_size > 1
.
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr; | ||
if (plugin->init_coordinator) { | ||
int global_process_rank = sys_util::GetEnvInt("RANK", 0); | ||
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't catch it earlier. What if the users start the single-host training in a non-torchrun way such as PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py
, then global_world_size
should default to local_world_size
? Similar reason for global_process_rank
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case is slightly wrong. The precedence should be $PJRT_LOCAL_PROCESS_COUNT
, then $WORLD_SIZE
, then 1
(default). I'll fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't it be WORLD_SIZE -> PJRT_LOCAL_PROCESS_COUNT or LOCAL_WORLD_SIZE -> 1?
If we use torchrun, then WORLD_SIZE will be set.
If we GPU_NUM_DEVICES=4 python3
, non-torchrun for single-host-multi-process, then WORLD_SIZE is not set and we rely on PJRT_LOCAL_PROCESS_COUNT or LOCAL_WORLD_SIZE
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, you're right. I tripped over this while testing as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
DevicePlugin
implementation toGetPjRtCApiClient
. Pybind does most of the hard work here, luckily.XlaCoordinator
before the client if required for client initialization.torch.distributed
store here directly instead.See #6242 for broader context